{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ExternalSource operator\n",
    "\n",
    "In this example, we will see how to use the `ExternalSource` operator with the PyTorch DALI iterator, which allows us to\n",
    "use an external data source as an input to the Pipeline.\n",
    "\n",
    "In order to achieve that, we have to define an iterator or generator class whose `next` function will\n",
    "return one or several `numpy` arrays."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import types\n",
    "import collections\n",
    "import numpy as np\n",
    "from random import shuffle\n",
    "from nvidia.dali.pipeline import Pipeline\n",
    "import nvidia.dali.types as types\n",
    "import nvidia.dali.fn as fn\n",
    "import torch\n",
    "\n",
    "# Number of samples per batch; passed to both the iterator and the pipeline\n",
    "batch_size = 3\n",
    "# Number of passes over the data made by the demonstration loop\n",
    "epochs = 3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Defining the Iterator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ExternalInputIterator(object):\n",
    "    \"\"\"Iterate over one shard of an image list in fixed-size batches.\n",
    "\n",
    "    The list is read from ``file_list.txt`` inside ``images_dir``. Every\n",
    "    non-blank line is split on a single space into a JPEG file name and an\n",
    "    integer label. Only the slice of the list that belongs to ``device_id``\n",
    "    (out of ``num_gpus`` shards) is kept.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, batch_size, device_id, num_gpus):\n",
    "        self.images_dir = \"../../data/images/\"\n",
    "        self.batch_size = batch_size\n",
    "        with open(self.images_dir + \"file_list.txt\", \"r\") as f:\n",
    "            # Filter on content: blank or whitespace-only lines carry no\n",
    "            # entry and would fail the filename/label split later on.\n",
    "            self.files = [line.rstrip() for line in f if line.strip()]\n",
    "        # whole data set size\n",
    "        self.data_set_len = len(self.files)\n",
    "        # based on the device_id and total number of GPUs - world size\n",
    "        # get proper shard\n",
    "        shard_begin = self.data_set_len * device_id // num_gpus\n",
    "        shard_end = self.data_set_len * (device_id + 1) // num_gpus\n",
    "        self.files = self.files[shard_begin:shard_end]\n",
    "        self.n = len(self.files)\n",
    "        # Start at the beginning so next() also works before iter() is called.\n",
    "        self.i = 0\n",
    "\n",
    "    def __iter__(self):\n",
    "        \"\"\"Rewind to the start of the shard, reshuffle it and return self.\"\"\"\n",
    "        self.i = 0\n",
    "        shuffle(self.files)\n",
    "        return self\n",
    "\n",
    "    def __next__(self):\n",
    "        \"\"\"Return the next ``(batch, labels)`` pair.\n",
    "\n",
    "        ``batch`` is a list of ``batch_size`` ``np.uint8`` arrays holding the\n",
    "        raw bytes of each file, and ``labels`` is a list of one-element\n",
    "        ``torch.uint8`` tensors. When fewer than ``batch_size`` files remain,\n",
    "        the batch is filled by wrapping around to the start of the shard.\n",
    "\n",
    "        Raises:\n",
    "            StopIteration: once the whole shard has been consumed. The\n",
    "                iterator is rewound and reshuffled first, so it can be reused.\n",
    "        \"\"\"\n",
    "        batch = []\n",
    "        labels = []\n",
    "\n",
    "        if self.i >= self.n:\n",
    "            self.__iter__()\n",
    "            raise StopIteration\n",
    "\n",
    "        for _ in range(self.batch_size):\n",
    "            # The modulo wraps around so that the last batch is always full.\n",
    "            jpeg_filename, label = self.files[self.i % self.n].split(\" \")\n",
    "            batch.append(\n",
    "                np.fromfile(self.images_dir + jpeg_filename, dtype=np.uint8)\n",
    "            )  # we can use numpy\n",
    "            labels.append(\n",
    "                torch.tensor([int(label)], dtype=torch.uint8)\n",
    "            )  # or PyTorch's native tensors\n",
    "            self.i += 1\n",
    "        return (batch, labels)\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the size of the whole data set, not only of this shard.\"\"\"\n",
    "        return self.data_set_len\n",
    "\n",
    "    next = __next__"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Defining the Pipeline\n",
    "\n",
    "Now let's define our pipeline. We need an instance of the ``Pipeline`` class and some operators which will define the processing graph. Our external source provides 2 outputs, which we can conveniently unpack by specifying ``num_outputs=2`` in the external source operator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ExternalSourcePipeline(batch_size, num_threads, device_id, external_data):\n",
    "    \"\"\"Assemble a DALI pipeline that is fed by ``external_data``.\n",
    "\n",
    "    The external source is unpacked into two outputs: encoded JPEGs and\n",
    "    labels. The JPEGs are decoded with ``device=\"mixed\"``, resized to\n",
    "    240x240 and cast to UINT8; the labels are forwarded unchanged.\n",
    "\n",
    "    Returns the configured ``Pipeline``; it is not built or run here.\n",
    "    \"\"\"\n",
    "    pipeline = Pipeline(\n",
    "        batch_size=batch_size, num_threads=num_threads, device_id=device_id\n",
    "    )\n",
    "    with pipeline:\n",
    "        encoded, labels = fn.external_source(\n",
    "            source=external_data, num_outputs=2, dtype=types.UINT8\n",
    "        )\n",
    "        decoded = fn.decoders.image(encoded, device=\"mixed\")\n",
    "        resized = fn.resize(decoded, resize_x=240, resize_y=240)\n",
    "        pipeline.set_outputs(fn.cast(resized, dtype=types.UINT8), labels)\n",
    "    return pipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using the Pipeline\n",
    "\n",
    "In the end, let us see how it works.\n",
    "\n",
    "`last_batch_padded` and `last_batch_policy` are set here only for demonstration purposes. The user may write any custom code and change the epoch size from epoch to epoch. In that case, it is recommended to set `size` to -1 and let the iterator just wait for the `StopIteration` exception from `iter_setup`.\n",
    "\n",
    "The `last_batch_padded` here tells the iterator that the difference between data set size and batch size alignment is padded by real data that could be skipped when provided to the framework (`last_batch_policy`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, iter 0, real batch size: 3\n",
      "epoch: 0, iter 1, real batch size: 3\n",
      "epoch: 0, iter 2, real batch size: 3\n",
      "epoch: 0, iter 3, real batch size: 3\n",
      "epoch: 0, iter 4, real batch size: 3\n",
      "epoch: 0, iter 5, real batch size: 3\n",
      "epoch: 0, iter 6, real batch size: 3\n",
      "epoch: 1, iter 0, real batch size: 3\n",
      "epoch: 1, iter 1, real batch size: 3\n",
      "epoch: 1, iter 2, real batch size: 3\n",
      "epoch: 1, iter 3, real batch size: 3\n",
      "epoch: 1, iter 4, real batch size: 3\n",
      "epoch: 1, iter 5, real batch size: 3\n",
      "epoch: 1, iter 6, real batch size: 3\n",
      "epoch: 2, iter 0, real batch size: 3\n",
      "epoch: 2, iter 1, real batch size: 3\n",
      "epoch: 2, iter 2, real batch size: 3\n",
      "epoch: 2, iter 3, real batch size: 3\n",
      "epoch: 2, iter 4, real batch size: 3\n",
      "epoch: 2, iter 5, real batch size: 3\n",
      "epoch: 2, iter 6, real batch size: 3\n"
     ]
    }
   ],
   "source": [
    "from nvidia.dali.plugin.pytorch import (\n",
    "    DALIClassificationIterator as PyTorchIterator,\n",
    ")\n",
    "from nvidia.dali.plugin.pytorch import LastBatchPolicy\n",
    "\n",
    "# Single shard: device_id=0 out of num_gpus=1\n",
    "eii = ExternalInputIterator(batch_size, 0, 1)\n",
    "pipe = ExternalSourcePipeline(\n",
    "    batch_size=batch_size, num_threads=2, device_id=0, external_data=eii\n",
    ")\n",
    "# No explicit size is passed; the end of an epoch is signalled by the\n",
    "# StopIteration raised from the external input iterator\n",
    "pii = PyTorchIterator(\n",
    "    pipe, last_batch_padded=True, last_batch_policy=LastBatchPolicy.PARTIAL\n",
    ")\n",
    "\n",
    "for e in range(epochs):\n",
    "    for i, data in enumerate(pii):\n",
    "        # Number of samples in the \"data\" entry of the first output element\n",
    "        real_batch_size = len(data[0][\"data\"])\n",
    "        print(f\"epoch: {e}, iter {i}, real batch size: {real_batch_size}\")\n",
    "    # Reset the framework iterator before starting the next epoch\n",
    "    pii.reset()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
